# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging  # allow-direct-logging
import os

import httpx
import llama_stack_client
import openai
import pytest

from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from tests.common.mcp import dependency_tools, make_mcp_server

from .fixtures.test_cases import (
    custom_tool_test_cases,
    file_search_test_cases,
    mcp_tool_test_cases,
    multi_turn_tool_execution_streaming_test_cases,
    multi_turn_tool_execution_test_cases,
    web_search_test_cases,
)
from .helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
from .streaming_assertions import StreamingValidator


@pytest.mark.parametrize("case", web_search_test_cases)
def test_response_non_streaming_web_search(responses_client, text_model_id, case):
    """A non-streaming response should perform a web search and answer from it."""
    result = responses_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    # Expect at least a web_search_call followed by an assistant message.
    assert len(result.output) > 1

    search_call = result.output[0]
    assert search_call.type == "web_search_call"
    assert search_call.status == "completed"

    message = result.output[1]
    assert message.type == "message"
    assert message.status == "completed"
    assert message.role == "assistant"
    assert len(message.content) > 0

    # The expected answer must appear somewhere in the generated text.
    assert case.expected.lower() in result.output_text.lower().strip()


@pytest.mark.parametrize("case", file_search_test_cases)
def test_response_non_streaming_file_search(
    responses_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path, case
):
    """File search retrieves expected content from a file attached to a vector store.

    Covers both inline file content (written under tmp_path) and fixture files
    referenced by relative path in the test case.
    """
    vector_store = new_vector_store(responses_client, "test_vector_store", embedding_model_id, embedding_dimension)

    if case.file_content:
        file_name = "test_response_non_streaming_file_search.txt"
        file_path = tmp_path / file_name
        file_path.write_text(case.file_content)
    elif case.file_path:
        file_path = os.path.join(os.path.dirname(__file__), "fixtures", case.file_path)
        file_name = os.path.basename(file_path)
    else:
        raise ValueError("No file content or path provided for case")

    file_response = upload_file(responses_client, file_name, file_path)

    # Attach our file to the vector store
    responses_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file_response.id,
    )

    # Wait for the file to be attached
    wait_for_file_attachment(responses_client, vector_store.id, file_response.id)

    # Point the file_search tools at our vector store. Build fresh dicts rather
    # than mutating case.tools in place: the parametrized cases are shared
    # module-level objects, and in-place mutation leaks state between tests.
    tools = [
        {**tool, "vector_store_ids": [vector_store.id]} if tool["type"] == "file_search" else tool
        for tool in case.tools
    ]

    # Create the response request, which should query our vector store
    response = responses_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=tools,
        stream=False,
        include=["file_search_call.results"],
    )

    # Verify the file_search_tool was called
    assert len(response.output) > 1
    assert response.output[0].type == "file_search_call"
    assert response.output[0].status == "completed"
    assert response.output[0].queries  # ensure it's some non-empty list
    assert response.output[0].results
    assert case.expected.lower() in response.output[0].results[0].text.lower()
    assert response.output[0].results[0].score > 0

    # Verify the output_text generated by the response
    assert case.expected.lower() in response.output_text.lower().strip()


def test_response_non_streaming_file_search_empty_vector_store(
    responses_client, text_model_id, embedding_model_id, embedding_dimension
):
    """A file_search against an empty vector store completes with no results."""
    vector_store = new_vector_store(responses_client, "test_vector_store", embedding_model_id, embedding_dimension)

    # Ask a question the (empty) store cannot answer.
    response = responses_client.responses.create(
        model=text_model_id,
        input="How many experts does the Llama 4 Maverick model have?",
        tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}],
        stream=False,
        include=["file_search_call.results"],
    )

    assert len(response.output) > 1

    search_call = response.output[0]
    assert search_call.type == "file_search_call"
    assert search_call.status == "completed"
    assert search_call.queries  # the tool was still invoked with some query
    assert not search_call.results  # empty store -> no hits

    # The model should still produce some answer text despite zero hits.
    assert response.output_text


def test_response_sequential_file_search(
    responses_client, text_model_id, embedding_model_id, embedding_dimension, tmp_path
):
    """Test file search across two chained responses using previous_response_id."""
    vector_store = new_vector_store(responses_client, "test_vector_store", embedding_model_id, embedding_dimension)

    # Write a small document whose answer the model should surface.
    doc_text = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
    doc_name = "test_sequential_file_search.txt"
    doc_path = tmp_path / doc_name
    doc_path.write_text(doc_text)

    uploaded = upload_file(responses_client, doc_name, doc_path)

    # Attach the uploaded file to the vector store and wait for it to index.
    responses_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=uploaded.id,
    )
    wait_for_file_attachment(responses_client, vector_store.id, uploaded.id)

    tools = [{"type": "file_search", "vector_store_ids": [vector_store.id]}]

    # Turn 1: the model should call file_search and find the answer.
    first = responses_client.responses.create(
        model=text_model_id,
        input="How many experts does the Llama 4 Maverick model have?",
        tools=tools,
        stream=False,
        include=["file_search_call.results"],
    )

    assert len(first.output) > 1
    search_call = first.output[0]
    assert search_call.type == "file_search_call"
    assert search_call.status == "completed"
    assert search_call.queries
    assert search_call.results
    assert "128" in first.output_text or "experts" in first.output_text.lower()

    # Turn 2: continue the conversation via previous_response_id.
    follow_up = responses_client.responses.create(
        model=text_model_id,
        input="Can you tell me more about the architecture?",
        tools=tools,
        stream=False,
        previous_response_id=first.id,
        include=["file_search_call.results"],
    )

    assert len(follow_up.output) >= 1
    assert follow_up.output_text

    # The follow-up should maintain context and end with a completed
    # assistant message.
    messages = [item for item in follow_up.output if item.type == "message"]
    assert len(messages) >= 1
    assert messages[-1].role == "assistant"
    assert messages[-1].status == "completed"


@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_non_streaming_mcp_tool(responses_client, text_model_id, case, caplog):
    """Exercise MCP tool execution, with and without server-side auth.

    Phase 1: an unauthenticated local MCP server — the response must list the
    server's tools, call get_boiling_point, and answer from the tool output.
    Phase 2: the same flow against a token-protected server — the first attempt
    (no token) must raise an auth error; after adding the token to the tool
    config the request must succeed.
    """
    with make_mcp_server() as mcp_server_info:
        # Rewrite the case's tool configs to point at the just-started server.
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        response = responses_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=tools,
            stream=False,
        )

        # Minimum shape: mcp_list_tools, mcp_call, final message.
        assert len(response.output) >= 3
        list_tools = response.output[0]
        assert list_tools.type == "mcp_list_tools"
        assert list_tools.server_label == "localmcp"
        assert len(list_tools.tools) == 2
        assert {t.name for t in list_tools.tools} == {
            "get_boiling_point",
            "greet_everyone",
        }

        call = response.output[1]
        assert call.type == "mcp_call"
        assert call.name == "get_boiling_point"
        assert json.loads(call.arguments) == {
            "liquid_name": "myawesomeliquid",
            "celsius": True,
        }
        assert call.error is None
        # The mock tool reports -100 for unknown liquids in celsius.
        assert "-100" in call.output

        # sometimes the model will call the tool again, so we need to get the last message
        message = response.output[-1]
        text_content = message.content[0].text
        assert "boiling point" in text_content.lower()

    # Phase 2: restart the server requiring a bearer token.
    with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        # The library client raises the server-side error type directly; HTTP
        # clients surface it as an HTTP/auth error from the client library.
        exc_type = (
            AuthenticationRequiredError
            if isinstance(responses_client, LlamaStackAsLibraryClient)
            else (httpx.HTTPStatusError, openai.AuthenticationError, llama_stack_client.AuthenticationError)
        )
        # Suppress expected auth error logs only for the failing auth attempt
        with caplog.at_level(
            logging.CRITICAL, logger="llama_stack.providers.inline.agents.meta_reference.responses.streaming"
        ):
            with pytest.raises(exc_type):
                responses_client.responses.create(
                    model=text_model_id,
                    input=case.input,
                    tools=tools,
                    stream=False,
                )

        # Supply the required token in-place on the mcp tool configs and retry.
        for tool in tools:
            if tool["type"] == "mcp":
                tool["authorization"] = "test-token"

        response = responses_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=tools,
            stream=False,
        )
        # With auth in place the full tool flow should run again.
        assert len(response.output) >= 3


@pytest.mark.parametrize("case", mcp_tool_test_cases)
def test_response_sequential_mcp_tool(responses_client, text_model_id, case):
    """Two chained responses against an MCP server via previous_response_id."""
    with make_mcp_server() as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        first = responses_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=tools,
            stream=False,
        )

        # Expect at minimum: mcp_list_tools, mcp_call, and a final message.
        assert len(first.output) >= 3

        listing = first.output[0]
        assert listing.type == "mcp_list_tools"
        assert listing.server_label == "localmcp"
        assert len(listing.tools) == 2
        assert {t.name for t in listing.tools} == {
            "get_boiling_point",
            "greet_everyone",
        }

        tool_call = first.output[1]
        assert tool_call.type == "mcp_call"
        assert tool_call.name == "get_boiling_point"
        assert json.loads(tool_call.arguments) == {
            "liquid_name": "myawesomeliquid",
            "celsius": True,
        }
        assert tool_call.error is None
        assert "-100" in tool_call.output

        # sometimes the model will call the tool again, so we need to get the last message
        assert "boiling point" in first.output[-1].content[0].text.lower()

        # Second turn, chained to the first.
        follow_up = responses_client.responses.create(
            model=text_model_id, input=case.input, tools=tools, stream=False, previous_response_id=first.id
        )

        assert len(follow_up.output) >= 1
        assert "boiling point" in follow_up.output[-1].content[0].text.lower()


@pytest.mark.parametrize("case", mcp_tool_test_cases)
@pytest.mark.parametrize("approve", [True, False])
def test_response_mcp_tool_approval(responses_client, text_model_id, case, approve):
    """With require_approval="always", an MCP call only runs once approved."""
    with make_mcp_server() as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)
        for tool in tools:
            tool["require_approval"] = "always"

        def check_tool_listing(listing):
            # Each turn re-lists the server's tools; verify the catalog.
            assert listing.type == "mcp_list_tools"
            assert listing.server_label == "localmcp"
            assert len(listing.tools) == 2
            assert {t.name for t in listing.tools} == {
                "get_boiling_point",
                "greet_everyone",
            }

        response = responses_client.responses.create(
            model=text_model_id,
            input=case.input,
            tools=tools,
            stream=False,
        )

        # First turn stops at an approval request instead of executing the tool.
        assert len(response.output) >= 2
        check_tool_listing(response.output[0])

        approval_request = response.output[1]
        assert approval_request.type == "mcp_approval_request"
        assert approval_request.name == "get_boiling_point"
        args = json.loads(approval_request.arguments)
        assert args["liquid_name"] == "myawesomeliquid"
        # celsius has a default value of True, so it may be omitted or explicitly set
        assert args.get("celsius", True) is True

        # send approval response
        response = responses_client.responses.create(
            previous_response_id=response.id,
            model=text_model_id,
            input=[{"type": "mcp_approval_response", "approval_request_id": approval_request.id, "approve": approve}],
            tools=tools,
            stream=False,
        )

        if not approve:
            # Denied: no mcp_call may appear anywhere in the output.
            assert len(response.output) >= 1
            assert all(item.type != "mcp_call" for item in response.output)
            return

        # Approved: the tool call executes and a final message follows.
        assert len(response.output) >= 3
        check_tool_listing(response.output[0])

        call = response.output[1]
        assert call.type == "mcp_call"
        assert call.name == "get_boiling_point"
        assert json.loads(call.arguments) == {
            "liquid_name": "myawesomeliquid",
            "celsius": True,
        }
        assert call.error is None
        assert "-100" in call.output

        # sometimes the model will call the tool again, so we need to get the last message
        assert "boiling point" in response.output[-1].content[0].text.lower()


@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_non_streaming_custom_tool(responses_client, text_model_id, case):
    """A custom function tool yields exactly one completed function_call."""
    result = responses_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(result.output) == 1
    call = result.output[0]
    assert call.type == "function_call"
    assert call.status == "completed"
    assert call.name == "get_weather"


@pytest.mark.parametrize("case", custom_tool_test_cases)
def test_response_function_call_ordering_1(responses_client, text_model_id, case):
    """A function_call result can be fed back via function_call_output on the next turn."""
    first = responses_client.responses.create(
        model=text_model_id,
        input=case.input,
        tools=case.tools,
        stream=False,
    )

    assert len(first.output) == 1
    call = first.output[0]
    assert call.type == "function_call"
    assert call.status == "completed"
    assert call.name == "get_weather"

    # Send the user prompt plus the tool result back, chained to the first turn.
    follow_up_inputs = [
        {
            "role": "user",
            "content": case.input,
        },
        {
            "type": "function_call_output",
            "output": "It is raining.",
            "call_id": call.call_id,
        },
    ]
    second = responses_client.responses.create(
        model=text_model_id, input=follow_up_inputs, tools=case.tools, stream=False, previous_response_id=first.id
    )
    assert len(second.output) == 1


def test_response_function_call_ordering_2(responses_client, text_model_id):
    """Two function calls answered out-of-band: calls first, then their outputs."""
    tools = [
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "additionalProperties": False,
                "properties": {
                    "location": {
                        "description": "City and country e.g. Bogotá, Colombia",
                        "type": "string",
                    }
                },
                "required": ["location"],
                "type": "object",
            },
        }
    ]
    inputs = [
        {
            "role": "user",
            "content": "Is the weather better in San Francisco or Los Angeles?",
        }
    ]
    response = responses_client.responses.create(
        model=text_model_id,
        input=inputs,
        tools=tools,
        stream=False,
    )

    # Collect the completed weather calls once; we need them twice below.
    weather_calls = [
        item
        for item in response.output
        if item.type == "function_call" and item.status == "completed" and item.name == "get_weather"
    ]

    # Ordering matters: append every call first, then every matching output.
    inputs.extend(weather_calls)
    for call in weather_calls:
        weather = "It is cloudy." if "Los Angeles" in call.arguments else "It is raining."
        inputs.append(
            {
                "type": "function_call_output",
                "output": weather,
                "call_id": call.call_id,
            }
        )

    response = responses_client.responses.create(
        model=text_model_id,
        input=inputs,
        tools=tools,
        stream=False,
    )
    assert len(response.output) == 1
    assert "Los Angeles" in response.output_text


@pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
def test_response_non_streaming_multi_turn_tool_execution(responses_client, text_model_id, case):
    """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        response = responses_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=tools,
        )

        # Bucket the output items by type.
        listings = [item for item in response.output if item.type == "mcp_list_tools"]
        calls = [item for item in response.output if item.type == "mcp_call"]
        messages = [item for item in response.output if item.type == "message"]

        # Exactly one tool listing, emitted at the start of the conversation.
        assert len(listings) == 1, f"Expected exactly 1 mcp_list_tools, got {len(listings)}"
        assert listings[0].server_label == "localmcp"
        assert len(listings[0].tools) == 5  # the dependency-tool fixture exposes 5 tools
        assert {t.name for t in listings[0].tools} == {
            "get_user_id",
            "get_user_permissions",
            "check_file_access",
            "get_experiment_id",
            "get_experiment_results",
        }

        # At least one tool call, and every call must have succeeded.
        assert len(calls) >= 1, f"Expected at least 1 mcp_call, got {len(calls)}"
        for call in calls:
            assert call.error is None, f"MCP call should not have errors, got: {call.error}"

        assert len(messages) >= 1, f"Expected at least 1 message output, got {len(messages)}"

        # The conversation must end with a completed assistant message.
        last_message = messages[-1]
        assert last_message.role == "assistant", f"Final message should be from assistant, got {last_message.role}"
        assert last_message.status == "completed", f"Final message should be completed, got {last_message.status}"
        assert len(last_message.content) > 0, "Final message should have content"

        expected_output = case.expected
        assert expected_output.lower() in response.output_text.lower(), (
            f"Expected '{expected_output}' to appear in response: {response.output_text}"
        )


@pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
def test_response_streaming_multi_turn_tool_execution(responses_client, text_model_id, case):
    """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence.

    Consumes the full event stream, runs the shared StreamingValidator checks
    over the collected chunks, then (when the terminal chunk carries the final
    response) re-verifies the same multi-turn invariants as the non-streaming
    variant: one mcp_list_tools, >=1 successful mcp_call, and a completed
    assistant message containing the expected answer.
    """
    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        tools = setup_mcp_tools(case.tools, mcp_server_info)

        stream = responses_client.responses.create(
            input=case.input,
            model=text_model_id,
            tools=tools,
            stream=True,
        )

        # Drain the stream; validation happens over the full chunk list.
        chunks = []
        for chunk in stream:
            chunks.append(chunk)

        # Use validator for common streaming checks
        validator = StreamingValidator(chunks)
        validator.assert_basic_event_sequence()
        validator.assert_response_consistency()
        validator.assert_has_tool_calls()
        validator.assert_has_mcp_events()
        validator.assert_rich_streaming()

        # Get the final response from the last chunk
        # NOTE(review): if the last chunk has no `response` attribute, every
        # assertion below is silently skipped — confirm this guard is intended
        # rather than an `assert hasattr(...)`.
        final_chunk = chunks[-1]
        if hasattr(final_chunk, "response"):
            final_response = final_chunk.response

            # Verify multi-turn MCP tool execution results
            mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
            mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
            message_outputs = [output for output in final_response.output if output.type == "message"]

            # Should have exactly 1 MCP list tools message (at the beginning)
            assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
            assert mcp_list_tools[0].server_label == "localmcp"
            assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
            expected_tool_names = {
                "get_user_id",
                "get_user_permissions",
                "check_file_access",
                "get_experiment_id",
                "get_experiment_results",
            }
            assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names

            # Should have at least 1 MCP call (the model should call at least one tool)
            assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

            # All MCP calls should be completed (verifies our tool execution works)
            for mcp_call in mcp_calls:
                assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

            # Should have at least one final message response
            assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

            # Final message should be from assistant and completed
            final_message = message_outputs[-1]
            assert final_message.role == "assistant", (
                f"Final message should be from assistant, got {final_message.role}"
            )
            assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
            assert len(final_message.content) > 0, "Final message should have content"

            # Check that the expected output appears in the response
            expected_output = case.expected
            assert expected_output.lower() in final_response.output_text.lower(), (
                f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
            )


def test_max_tool_calls_with_function_tools(responses_client, text_model_id):
    """Test handling of max_tool_calls with function tools in responses."""

    max_tool_calls = 1
    # Both tools share the same parameter schema; only name/description differ.
    tools = [
        {
            "type": "function",
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name (e.g., 'New York', 'London')",
                    },
                },
            },
        }
        for name, description in [
            ("get_weather", "Get weather information for a specified location"),
            ("get_time", "Get current time for a specified location"),
        ]
    ]

    response = responses_client.responses.create(
        model=text_model_id,
        input="Can you tell me the weather in Paris and the current time?",
        tools=tools,
        stream=False,
        max_tool_calls=max_tool_calls,
    )

    # max_tool_calls must NOT limit client-executed function tools: both
    # calls are still emitted even with max_tool_calls=1.
    assert len(response.output) == 2
    for call, expected_name in zip(response.output, ["get_weather", "get_time"]):
        assert call.type == "function_call"
        assert call.name == expected_name
        assert call.status == "completed"

    # The field itself is echoed back on the response.
    assert response.max_tool_calls == max_tool_calls


def test_max_tool_calls_invalid(responses_client, text_model_id):
    """Test handling of invalid max_tool_calls in responses."""

    bad_value = 0
    tools = [
        {"type": "web_search"},
    ]

    # A max_tool_calls below 1 must be rejected. The library client raises
    # ValueError directly; HTTP-based clients surface a BadRequestError.
    with pytest.raises((ValueError, llama_stack_client.BadRequestError, openai.BadRequestError)) as excinfo:
        responses_client.responses.create(
            model=text_model_id,
            input="Search for today's top technology news.",
            tools=tools,
            stream=False,
            max_tool_calls=bad_value,
        )

    error_message = str(excinfo.value)
    assert f"Invalid max_tool_calls={bad_value}; should be >= 1" in error_message, (
        f"Expected error message about invalid max_tool_calls, got: {error_message}"
    )


def test_max_tool_calls_with_mcp_tools(responses_client, text_model_id):
    """Test handling of max_tool_calls with mcp tools in responses."""

    with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
        prompt = "Get the experiment ID for 'boiling_point' and get the user ID for 'charlie'"
        tools = [
            {"type": "mcp", "server_label": "localmcp", "server_url": mcp_server_info["server_url"]},
        ]

        def bucket(outputs):
            # Split response output into (list_tools, calls, messages) by type.
            return (
                [item for item in outputs if item.type == "mcp_list_tools"],
                [item for item in outputs if item.type == "mcp_call"],
                [item for item in outputs if item.type == "message"],
            )

        # Baseline: no max_tool_calls -> both tool calls run.
        unlimited = responses_client.responses.create(
            model=text_model_id,
            input=prompt,
            tools=tools,
            stream=False,
        )
        assert len(unlimited.output) == 4
        listings, calls, messages = bucket(unlimited.output)
        assert len(listings) == 1
        assert len(calls) == 2, f"Expected two mcp calls, got {len(calls)}"
        assert len(messages) == 1, f"Expected one message output, got {len(messages)}"

        # max_tool_calls=1 caps server-side execution at a single mcp call.
        capped = responses_client.responses.create(
            model=text_model_id,
            input=prompt,
            tools=tools,
            stream=False,
            max_tool_calls=1,
        )
        assert len(capped.output) == 3
        listings, calls, messages = bucket(capped.output)
        assert len(listings) == 1
        assert len(calls) == 1, f"Expected one mcp call, got {len(calls)}"
        assert len(messages) == 1, f"Expected one message output, got {len(messages)}"
        assert capped.max_tool_calls == 1

        # max_tool_calls=5 is a non-binding cap: both calls still run.
        roomy = responses_client.responses.create(
            model=text_model_id,
            input=prompt,
            tools=tools,
            stream=False,
            max_tool_calls=5,
        )
        assert len(roomy.output) == 4
        listings, calls, messages = bucket(roomy.output)
        assert len(listings) == 1
        assert len(calls) == 2, f"Expected two mcp calls, got {len(calls)}"
        assert len(messages) == 1, f"Expected one message output, got {len(messages)}"
        assert roomy.max_tool_calls == 5
